import itertools
import json
import logging
from datetime import timedelta
from functools import cached_property
from typing import Any

import tenacity
from channels.db import database_sync_to_async
from channels.exceptions import DenyConnection, StopConsumer
from channels.generic.websocket import AsyncJsonWebsocketConsumer
from django.core.exceptions import ValidationError
from django.core.serializers.json import DjangoJSONEncoder
from django.db import models, transaction
from django.utils import timezone
from django.utils.crypto import get_random_string
from rest_framework.exceptions import ValidationError as RestFrameworkValidationError
from uvicorn.protocols.utils import ClientDisconnected

from sysreptor.pentests.collab.text_transformations import (
    EditorSelection,
    SelectionRange,
    Update,
    rebase_updates,
)
from sysreptor.pentests.models import (
    CollabClientInfo,
    CollabEvent,
    CollabEventType,
)
from sysreptor.pentests.models.collab import collab_context
from sysreptor.users.serializers import PentestUserSerializer
from sysreptor.utils.elasticapm import elasticapm_capture_websocket_transaction
from sysreptor.utils.history import history_context
from sysreptor.utils.utils import get_random_color

log = logging.getLogger(__name__)


class WebsocketConsumerBase(AsyncJsonWebsocketConsumer):
    last_permission_check_time = None
    initial_path = None
    related_id = None

    @property
    def user(self):
        if (user := self.scope.get('user')) and not user.is_anonymous:
            return user
        return None

    async def get_related_id(self):
        return None

    async def dispatch(self, message):
        try:
            if not message.get('type', '').startswith('websocket.') and not await self.check_permission(action='read', skip_on_recent_check=True):
                await self.close(code=4443)
                return

            with history_context(history_user=self.user):
                await super().dispatch(message)
        except (StopConsumer, ClientDisconnected):
            raise
        except Exception as ex:
            log.exception(ex)
            raise ex

    async def websocket_connect(self, message):
        async with elasticapm_capture_websocket_transaction(scope=self.scope, event={'type': 'websocket.connect'}):
            self.related_id = await self.get_related_id()

            # Set user.admin_permissions_enabled
            if self.user and self.scope.get('session', {}).get('admin_permissions_enabled'):
                self.user.admin_permissions_enabled = True

            with history_context(history_user=self.user):
                out = await super().websocket_connect(message)

            # Log connection
            user = '<none>'
            if self.user:
                user = self.user.username
            logging.info(f'CONNECT {self.scope['path']} (user={user})')

            return out

    async def websocket_receive(self, message):
        event = await self.decode_json(message.get('text', '{}'))
        if event.get('type') == 'ping':
            await self.send_json({'type': 'ping'})
            return

        async with elasticapm_capture_websocket_transaction(scope=self.scope, event=event):
            if not await self.check_permission(action='write', event=event):
                await self.close(code=4443)
                return

            try:
                with history_context(history_user=self.user):
                    return await super().websocket_receive(message)
            except ValidationError as ex:
                await self.send_json({
                    'type': 'error',
                    'message': ex.message,
                })
            except RestFrameworkValidationError as ex:
                await self.send_json({
                    'type': 'error',
                    'message': str(ex.detail),
                })

    async def websocket_disconnect(self, message):
        try:
            return await super().websocket_disconnect(message)
        finally:
            user = '<none>'
            if self.user:
                user = self.user.username
            logging.info(f'DISCONNECT {self.scope['path']} (user={user})')

    async def encode_json(self, content):
        return json.dumps(content, cls=DjangoJSONEncoder)

    @property
    def group_name(self) -> str:
        raise NotImplementedError()

    @cached_property
    def client_id(self) -> str:
        if client_id := self.scope.get('client_id'):
            return client_id
        elif self.user:
            return f'{self.user.id}/{get_random_string(8)}'
        else:
            return f'anonymous/{get_random_string(8)}'

    @cached_property
    def client_color(self) -> str:
        if self.user and self.user.color:
            return self.user.color
        return get_random_color()

    @database_sync_to_async
    def check_permission(self, skip_on_recent_check=False, action=None, **kwargs):
        # Skip permission check if it was done recently
        if skip_on_recent_check and self.last_permission_check_time and self.last_permission_check_time + timedelta(seconds=60) >= timezone.now():
            return True

        # Check if session is still valid
        session = self.scope.get('session')
        if self.user and (
            not session or not session.session_key or \
            session.expire_date < timezone.now() or \
            (not session.exists(session.session_key))
        ):
            return False

        # Check custom permissions
        res = self.has_permission(action=action, **kwargs)
        self.last_permission_check_time = timezone.now()
        return res

    def has_permission(self, **kwargs):
        return True

    async def delete_client_info(self):
        pass

    async def connect(self):
        if not await self.check_permission(action='connect') or not self.channel_layer:
            raise DenyConnection()

        await super().connect()
        await self.channel_layer.group_add(self.group_name, self.channel_name)

    async def disconnect(self, code=None):
        if self.channel_layer:
            await self.channel_layer.group_discard(self.group_name, self.channel_name)
        await super().disconnect(code=code)


class CollabConsumerBase(WebsocketConsumerBase):
    """
    Websocket consumer base for collaborative editing.

    Tracks connected clients in CollabClientInfo, announces connects/disconnects
    to the group and forwards collab events received via the channel layer to the client.
    """

    async def get_related_id(self):
        """Return the ID of the object this collab session belongs to. Must be implemented by subclasses."""
        raise NotImplementedError()

    @database_sync_to_async
    def create_client_info(self):
        """Store a CollabClientInfo entry for this connection."""
        CollabClientInfo.objects.create(
            related_id=self.related_id,
            user=self.user,
            client_id=self.client_id,
            client_color=self.client_color,
            path=self.initial_path,
        )

    @database_sync_to_async
    def delete_client_info(self):
        """Delete all CollabClientInfo entries of this connection's client_id."""
        own_entries = CollabClientInfo.objects.filter(client_id=self.client_id)
        own_entries.delete()

    def filter_path(self, obj_or_qs):
        """
        Restrict a queryset or event dict to ``initial_path``.

        Without ``initial_path`` the argument is returned unchanged. Querysets are
        filtered by path prefix. Dicts are returned only if their ``path`` starts with
        ``initial_path``; everything else yields None.
        """
        prefix = self.initial_path
        if not prefix:
            return obj_or_qs

        if isinstance(obj_or_qs, models.QuerySet):
            return obj_or_qs.filter(path__startswith=prefix)
        if not isinstance(obj_or_qs, dict):
            return None

        obj_path = obj_or_qs.get('path') or ''
        return obj_or_qs if obj_path.startswith(prefix) else None

    def get_client_infos(self):
        """Return a list of dicts describing the clients stored for ``related_id`` (restricted by ``filter_path``)."""
        queryset = self.filter_path(
            CollabClientInfo.objects
            .filter(related_id=self.related_id)
            .select_related('user'),
        )

        infos = []
        for client in queryset:
            infos.append({
                'client_id': client.client_id,
                'client_color': client.client_color,
                'user': PentestUserSerializer(client.user).data if client.user else None,
                'path': client.path,
            })
        return infos

    async def get_connect_message(self):
        """Build the event announcing this client to the group."""
        message = {
            'type': CollabEventType.CONNECT,
            'client_id': self.client_id,
            'path': self.initial_path,
        }
        message['client'] = {
            'client_id': self.client_id,
            'client_color': self.client_color,
            'user': PentestUserSerializer(self.user).data if self.user else None,
        }
        return message

    async def get_disconnect_message(self):
        """Build the event announcing that this client left."""
        return {
            'type': CollabEventType.DISCONNECT,
            'client_id': self.client_id,
            'path': self.initial_path,
        }

    async def connect(self):
        """Accept the connection, register the client, send the initial message and announce the connect."""
        await super().connect()
        await self.create_client_info()

        # get_initial_message() is not defined in this class; it has to be provided by subclasses.
        initial_msg = await self.get_initial_message()
        if initial_msg:
            await self.send_json(initial_msg)

        connect_msg = await self.get_connect_message()
        if connect_msg:
            await self.send_colllab_event(connect_msg)

    async def disconnect(self, code=None):
        """Unregister the client, announce the disconnect and run the parent disconnect logic."""
        await self.delete_client_info()

        disconnect_msg = await self.get_disconnect_message()
        if disconnect_msg:
            await self.send_colllab_event(disconnect_msg)

        await super().disconnect(code=code)

    async def dispatch(self, message):
        """Dispatch a message; on any exception remove the client info before re-raising."""
        try:
            await super().dispatch(message=message)
        except Exception:
            await self.delete_client_info()
            raise

    async def send_colllab_event(self, event):
        """
        Publish an event to the group as a ``collab_event`` message.

        CollabEvent instances are sent by reference (id and path); other events are
        expected to be dicts and are embedded in the message. Nothing is sent for
        falsy events or when no channel layer is configured.
        """
        if not event or not self.channel_layer:
            return

        if isinstance(event, CollabEvent):
            group_message = {
                'type': 'collab_event',
                'id': str(event.id),
                'path': event.path,
            }
        else:
            group_message = {
                'type': 'collab_event',
                'path': event.get('path'),
                'event': event,
            }
        await self.channel_layer.group_send(self.group_name, group_message)

    async def collab_event(self, event):
        """Handle a ``collab_event`` group message: resolve the event data and send it to the client."""
        if not self.filter_path(event):
            return

        if event.get('id'):
            # Retry fetching event from DB: DB transactions can cause the channels event to arrive before event data is committed to the DB
            @tenacity.retry(
                retry=tenacity.retry_if_exception_type(CollabEvent.DoesNotExist),
                stop=tenacity.stop_after_delay(1),
                wait=tenacity.wait_fixed(0.1),
            )
            @database_sync_to_async
            def fetch_stored_event(event_id):
                return CollabEvent.objects.get(id=event_id)

            stored_event = await fetch_stored_event(event['id'])
            event_data = stored_event.to_dict()
        else:
            event_data = event.get('event')
            if not isinstance(event_data, dict):
                return

        for outgoing in await self.filter_events([event_data]):
            await self.send_json(outgoing)

    async def filter_events(self, events: list[dict]):
        """Hook to filter or modify events before they are sent to the client. Returns the events unchanged."""
        return events


class CollabUpdateTextMixin:
    """Mixin implementing text updates: client changes are rebased over newer stored updates, applied and recorded as a CollabEvent."""

    def get_object_for_update(self, content):
        """Resolve ``(obj, path, definition)`` for an update message. Must be implemented by the consumer."""
        raise NotImplementedError()

    def perform_update_text(self, obj, path, changes, definition, content):
        """Apply ``changes`` to ``obj`` and return ``(obj, event_data)``. Must be implemented by the consumer."""
        raise NotImplementedError()

    @database_sync_to_async
    @transaction.atomic()
    def collab_update_text(self, content):
        """
        Apply a text update sent by the client and store it as UPDATE_TEXT event.

        :param content: message dict; reads ``version``, ``path`` and optionally ``updates`` and ``selection``
        :return: the created CollabEvent
        :raises ValidationError: if no updates remain after rebasing
        """
        obj, path, definition = self.get_object_for_update(content)

        version = content['version']
        # TODO: reject updates for versions that are too old
        # * check if version is too old and if there are updates in between
        # * simple timestamp comparison is not enough, because when there were no updates in between, the version is still valid
        # * checking version < note.version is not enough, because of concurrent updates (e.g. old version, update1 succeeds, update2 fails because of updated version)

        # Text updates on the same path stored after the version the client based its changes on
        newer_events = (
            CollabEvent.objects
            .filter(related_id=self.related_id)
            .filter(path=content['path'])
            .filter(type=CollabEventType.UPDATE_TEXT)
            .filter(version__gt=version)
            .order_by('version')
        )

        # Rebase updates
        client_updates = [
            Update.from_dict(raw | {'client_id': self.client_id, 'version': version})
            for raw in content.get('updates', [])
        ]
        client_selection = None
        if content.get('selection'):
            client_selection = EditorSelection.from_dict(content['selection'])
        # NOTE(review): stored updates are tagged with the client's base version here,
        # while rebase_selection() uses the stored event's own version - confirm this is intended.
        concurrent_updates = [
            Update.from_dict(raw | {'client_id': stored.client_id, 'version': version})
            for stored in newer_events
            for raw in stored.data.get('updates', [])
        ]
        updates, selection = rebase_updates(
            updates=client_updates,
            selection=client_selection,
            over=concurrent_updates,
        )
        if not updates:
            raise ValidationError('No updates')

        # Compose all rebased updates into a single change set and update in DB
        remaining = iter(updates)
        changes = next(remaining).changes
        for update in remaining:
            changes = changes.compose(update.changes)

        with collab_context(prevent_events=True):
            obj, event_data = self.perform_update_text(obj=obj, path=path, changes=changes, definition=definition, content=content)

        # Update client info
        own_client_info = CollabClientInfo.objects.filter(client_id=self.client_id)
        own_client_info.update(path=content['path'])

        # Store OT event in DB
        data = {'updates': [update.to_dict() for update in updates]}
        if selection:
            data['selection'] = selection.to_dict()
        data.update(event_data or {})

        return CollabEvent.objects.create(
            related_id=self.related_id,
            path=content['path'],
            type=CollabEventType.UPDATE_TEXT,
            created=obj.updated,
            version=obj.updated.timestamp(),
            client_id=self.client_id,
            data=data,
        )


class CollabUpdateKeyMixin:
    """Mixin implementing key updates: a value is written to the object and recorded as a CollabEvent."""

    def get_object_for_update(self, content) -> tuple[Any, list[str], Any]:
        """
        Resolve ``(obj, path, definition)`` for an update message. Must be implemented by the consumer.

        The result is unpacked into three values by ``collab_update_key()``.
        """
        raise NotImplementedError()

    def perform_update_key(self, obj, path, value, definition, content):
        """Write ``value`` to ``obj`` and return ``(obj, event_data)``. Must be implemented by the consumer."""
        raise NotImplementedError()

    @database_sync_to_async
    @transaction.atomic()
    def collab_update_key(self, content):
        """
        Apply a key update sent by the client and store it as UPDATE_KEY event.

        :param content: message dict; reads ``value``, ``path`` and optionally ``update_awareness``
        :return: the created CollabEvent
        """
        obj, path, definition = self.get_object_for_update(content)

        # Update in DB
        with collab_context(prevent_events=True):
            obj, event_data = self.perform_update_key(obj=obj, path=path, value=content['value'], definition=definition, content=content)

        # Update client info (only if the client asked for it)
        if content.get('update_awareness', False):
            CollabClientInfo.objects \
                .filter(client_id=self.client_id) \
                .update(path=content['path'])

        # Store OT event in DB
        return CollabEvent.objects.create(
            related_id=self.related_id,
            path=content['path'],
            type=CollabEventType.UPDATE_KEY,
            created=obj.updated,
            version=obj.updated.timestamp(),
            client_id=self.client_id,
            data={
                'value': content['value'],
                # Extra event data returned by perform_update_key() overrides the defaults above
                **(event_data or {}),
            },
        )


def rebase_selection(selection: EditorSelection|SelectionRange, over_events, version):
    """
    Map a selection through all text updates stored after ``version``.

    :param selection: selection to rebase
    :param over_events: queryset of events to rebase over; narrowed to UPDATE_TEXT events newer than ``version``
    :param version: version the selection is based on
    :return: tuple of the mapped selection and the highest version among ``version`` and the applied updates
    """
    newer_text_events = (
        over_events
        .filter(type=CollabEventType.UPDATE_TEXT)
        .filter(version__gt=version)
        .order_by('version')
    )
    pending_updates = [
        Update.from_dict(raw | {'client_id': event.client_id, 'version': event.version})
        for event in newer_text_events
        for raw in event.data.get('updates', [])
    ]
    latest_version = max([*(update.version for update in pending_updates), version])

    rebased = selection
    for update in pending_updates:
        rebased = rebased.map(update.changes)
    return rebased, latest_version


class CollabUpdateAwarenessMixin:
    """Mixin implementing awareness updates (current path and selection of a client)."""

    @database_sync_to_async
    def collab_update_awareness(self, content):
        """
        Store the client's current path and build the awareness event.

        :param content: message dict; reads ``version`` and optionally ``path`` and ``selection``
        :return: awareness event dict; contains ``selection`` only if one was sent and could be rebased
        """
        path = content.get('path') or self.initial_path
        version = content['version']

        selection = None
        if path and content.get('selection'):
            path_events = CollabEvent.objects.filter(related_id=self.related_id).filter(path=path)
            try:
                selection, _rebased_version = rebase_selection(
                    selection=EditorSelection.from_dict(content['selection']),
                    over_events=path_events,
                    version=version,
                )
            except ValueError:
                # Reset selection position, keep path info
                selection = None

        # Update client info
        own_client_info = CollabClientInfo.objects.filter(client_id=self.client_id)
        own_client_info.update(path=path)

        awareness = {
            'type': CollabEventType.AWARENESS,
            'path': path,
            'client_id': self.client_id,
        }
        if selection:
            awareness['selection'] = selection.to_dict()
        return awareness


class GenericCollabMixin(CollabUpdateKeyMixin, CollabUpdateTextMixin, CollabUpdateAwarenessMixin):
    """Combines the key, text and awareness update mixins."""
